{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e9dbcde",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c37e1fda-c7f1-46e7-a5d4-19fa05c36ac1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from typing import List, Tuple, Any\n",
    "\n",
    "import numpy as np\n",
    "import time\n",
    "from torch import Tensor\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "\n",
    "from optimum.onnxruntime import AutoOptimizationConfig, ORTModelForFeatureExtraction, ORTOptimizer\n",
    "from optimum.pipelines import pipeline\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78a65856",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tokenizer and export the model to the ONNX format\n",
    "# model_id = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
    "# model_id = \"thenlper/gte-base\"\n",
    "# model_id = \"intfloat/multilingual-e5-large\"\n",
    "model_id = \"BAAI/bge-small-en-v1.5\"\n",
    "save_dir = f\"fast-{model_id.split('/')[1]}\"\n",
    "print(save_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1ecf0b6-db81-4da3-b47f-e31460ccfbf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_model = AutoModel.from_pretrained(model_id)\n",
    "hf_tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "# The input texts can be in any language, not just English.\n",
    "# Each input text should start with \"query: \" or \"passage: \", even for non-English texts.\n",
    "# For tasks other than retrieval, you can simply use the \"query: \" prefix.\n",
    "input_texts = [\n",
    "    \"query: how much protein should a female eat\",\n",
    "    \"query: 南瓜的家常做法\",\n",
    "    \"query: भारत का राष्ट्रीय खेल कौन-सा है?\",  # Hindi text\n",
    "    \"query: భారత్ దేశంలో రాష్ట్రపతి ఎవరు?\",  # Telugu text\n",
    "    \"query: இந்தியாவின் தேசிய கோப்பை எது?\",  # Tamil text\n",
    "    \"query: ಭಾರತದಲ್ಲಿ ರಾಷ್ಟ್ರಪತಿ ಯಾರು?\",  # Kannada text\n",
    "    \"query: ഇന്ത്യയുടെ രാഷ്ട്രീയ ഗാനം എന്താണ്?\",  # Malayalam text\n",
    "]\n",
    "\n",
    "english_texts = [\n",
    "    \"India: Where the Taj Mahal meets spicy curry.\",\n",
    "    \"Machine Learning: Turning data into knowledge, one algorithm at a time.\",\n",
    "    \"Python: The language that makes programming a piece of cake.\",\n",
    "    \"fastembed: Accelerating embeddings for lightning-fast similarity search.\",\n",
    "    \"Qdrant: The ultimate tool for high-dimensional indexing and search.\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f8c761c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:\n",
    "    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)\n",
    "    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]\n",
    "\n",
    "\n",
    "def hf_embed(model_id: str, inputs: List[str]):\n",
    "    # Tokenize the input texts\n",
    "    batch_dict = hf_tokenizer(inputs, max_length=512, padding=True, truncation=True, return_tensors=\"pt\")\n",
    "\n",
    "    outputs = hf_model(**batch_dict)\n",
    "    embeddings = average_pool(outputs.last_hidden_state, batch_dict[\"attention_mask\"])\n",
    "\n",
    "    # normalize embeddings\n",
    "    embeddings = F.normalize(embeddings, p=2, dim=1)\n",
    "    return embeddings.detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69bb4501",
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_embed(inputs=english_texts, model_id=model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "451dbd16",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "model = ORTModelForFeatureExtraction.from_pretrained(model_id, export=True)\n",
    "\n",
    "# Remove all existing files in the save_dir using Path.unlink()\n",
    "save_dir = Path(save_dir)\n",
    "save_dir.mkdir(parents=True, exist_ok=True)\n",
    "for p in save_dir.iterdir():\n",
    "    p.unlink()\n",
    "\n",
    "# Load the optimization configuration detailing the optimization we wish to apply\n",
    "optimization_config = AutoOptimizationConfig.O4()\n",
    "optimizer = ORTOptimizer.from_pretrained(model)\n",
    "\n",
    "optimizer.optimize(save_dir=save_dir, optimization_config=optimization_config, use_external_data_format=True)\n",
    "model = ORTModelForFeatureExtraction.from_pretrained(save_dir)\n",
    "\n",
    "tokenizer.save_pretrained(save_dir)\n",
    "# model.save_pretrained(save_dir)\n",
    "# model.push_to_hub(\"new_path_for_directory\", repository_id=\"my-onnx-repo\", use_auth_token=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8422cddd",
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_quant_embed = pipeline(\n",
    "    \"feature-extraction\", model=model, accelerator=\"ort\", tokenizer=tokenizer, return_tensors=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51fa5775",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = onnx_quant_embed(inputs=english_texts)\n",
    "F.normalize(embeddings[4])[:, 0], english_texts[4], len(embeddings), len(english_texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df405d70",
   "metadata": {},
   "outputs": [],
   "source": [
    "def measure_pipeline_time(pipeline, input_texts: List[str], num_runs=10, **kwargs: Any) -> Tuple[float, float]:\n",
    "    \"\"\"Measures the time it takes to run the pipeline on the input texts.\"\"\"\n",
    "    times = []\n",
    "    total_chars = sum(len(text) for text in input_texts)\n",
    "    for _ in range(num_runs):\n",
    "        start_time = time.time()\n",
    "        _ = pipeline(inputs=input_texts, **kwargs)\n",
    "        end_time = time.time()\n",
    "        times.append(end_time - start_time)\n",
    "\n",
    "    mean_time = np.mean(times)\n",
    "    std_dev = np.std(times)\n",
    "    chars_per_second = total_chars / mean_time\n",
    "    return mean_time, std_dev, chars_per_second"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "2d72aba5",
   "metadata": {},
   "source": [
    "# Ours"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b881152",
   "metadata": {},
   "outputs": [],
   "source": [
    "_, _, chars_per_sec = measure_pipeline_time(onnx_quant_embed, input_texts)\n",
    "print(f\"Multilingual Speed: {chars_per_sec:.2f} chars/sec\")\n",
    "_, _, chars_per_sec = measure_pipeline_time(onnx_quant_embed, english_texts)\n",
    "print(f\"English Speed: {chars_per_sec:.2f} chars/sec\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "49e1daf8",
   "metadata": {},
   "source": [
    "# Original"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61b3bf53",
   "metadata": {},
   "outputs": [],
   "source": [
    "_, _, chars_per_sec = measure_pipeline_time(hf_embed, input_texts=input_texts, model_id=model_id)\n",
    "print(f\"Multilingual Speed: {chars_per_sec:.2f} chars/sec\")\n",
    "_, _, chars_per_sec = measure_pipeline_time(hf_embed, input_texts=english_texts, model_id=model_id)\n",
    "print(f\"English Speed: {chars_per_sec:.2f} chars/sec\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f0b7da8f-ffe7-4f58-95dd-7e9836f19328",
   "metadata": {},
   "source": [
    "# Compress & Upload\n",
    "\n",
    "## Compress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "578b1d74",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "import tarfile\n",
    "\n",
    "save_dir = Path(\"../local_cache/fast-bge-small-en-v1.5\")\n",
    "\n",
    "\n",
    "def compress(directory_path):\n",
    "    directory_path = Path(directory_path)\n",
    "    assert directory_path.exists(), f\"{directory_path} does not exist\"\n",
    "    output_filename = directory_path.name + \".tar.gz\"\n",
    "    if Path(output_filename).exists():\n",
    "        print(\"We've an output file already? Manually delete that first\")\n",
    "        return output_filename\n",
    "\n",
    "    with tarfile.open(output_filename, \"w:gz\") as tar:\n",
    "        tar.add(directory_path, arcname=os.path.basename(directory_path))\n",
    "    return output_filename\n",
    "\n",
    "\n",
    "compressed_file_name = compress(save_dir)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "96cdf140-eca8-4778-9ebe-947988b4cfcb",
   "metadata": {},
   "source": [
    "## Upload to Qdrant Google Cloud Storage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1dab9595",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/homebrew/Caskroom/miniconda/base/envs/fst/lib/python3.9/site-packages/google/auth/_default.py:76: UserWarning: Your application has authenticated using end user credentials from Google Cloud SDK without a quota project. You might receive a \"quota exceeded\" or \"API not enabled\" error. See the following page for troubleshooting: https://cloud.google.com/docs/authentication/adc-troubleshooting/user-creds. \n",
      "  warnings.warn(_CLOUD_SDK_CREDENTIALS_WARNING)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "File fast-bge-small-en-v1.5.tar.gz uploaded to qdrant-fastembed.\n"
     ]
    }
   ],
   "source": [
    "from google.cloud import storage\n",
    "\n",
    "\n",
    "def upload(bucket_name, source_file_path):\n",
    "    storage_client = storage.Client(project=\"main\")\n",
    "    bucket = storage_client.bucket(bucket_name)\n",
    "    blob = bucket.blob(os.path.basename(source_file_path))\n",
    "\n",
    "    blob.upload_from_filename(source_file_path)\n",
    "\n",
    "    print(f\"File {source_file_path} uploaded to {bucket_name}.\")\n",
    "\n",
    "\n",
    "upload(\"qdrant-fastembed\", source_file_path=compressed_file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "731554f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the directory and compressed file\n",
    "!rm -rvf {save_dir}\n",
    "!rm -vf {save_dir}.tar.gz"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
